{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "reinforcement_pm.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "authorship_tag": "ABX9TyP/iy3Nqhs0Rruzh0H51VL9",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/letianzj/QuantResearch/blob/master/ml/reinforcement_pm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Reinforcement Portfolio Manager"
      ],
      "metadata": {
        "id": "EeHaU9gwuOHg"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Setup\n",
        "\n",
        "Uncomment to execute once"
      ],
      "metadata": {
        "id": "XfAGpFBxuSxX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# !sudo apt-get update\n",
        "# !pip install yfinance\n",
        "# !pip install ta\n",
        "# !pip install -U gym==0.21.0\n",
        "# !pip install -U quanttrader==0.5.5\n",
        "# !pip install -U pyfolio==0.9.2\n",
        "\n",
        "# !sudo apt-get install -y xvfb ffmpeg freeglut3-dev\n",
        "# !pip install 'imageio==2.4.0'\n",
        "# !pip install pyvirtualdisplay\n",
        "# !pip install tf-agents[reverb]\n",
        "# !pip install pyglet\n",
        "# !pip install -U PyYaml==3.13"
      ],
      "metadata": {
        "id": "gSR89jOVuRWv"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Restart the runtime for PyYaml==3.13 to take effect; otherwise pyfolio will fail with a yaml.load error.\n",
        "\n",
        "Code below might need to run twice."
      ],
      "metadata": {
        "id": "z9CV3FXvud3l"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import io\n",
        "import tempfile\n",
        "import shutil\n",
        "import zipfile\n",
        "from google.colab import files\n",
        "\n",
        "from datetime import datetime\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import yfinance as yf\n",
        "import gym\n",
        "import quanttrader as qt\n",
        "from quanttrader import PortfolioEnv\n",
        "import pyfolio as pf\n",
        "\n",
        "import tensorflow as tf\n",
        "from tf_agents.agents.dqn import dqn_agent\n",
        "from tf_agents.drivers import py_driver\n",
        "from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver\n",
        "from tf_agents.environments import tf_py_environment\n",
        "from tf_agents.environments import suite_gym\n",
        "from tf_agents.eval import metric_utils\n",
        "from tf_agents.metrics import tf_metrics\n",
        "from tf_agents.networks import sequential, q_network, network\n",
        "from tf_agents.policies import py_tf_eager_policy\n",
        "from tf_agents.policies import random_tf_policy\n",
        "from tf_agents.policies import policy_saver\n",
        "from tf_agents.replay_buffers import TFUniformReplayBuffer\n",
        "from tf_agents.trajectories import trajectory\n",
        "from tf_agents.specs import tensor_spec\n",
        "from tf_agents.utils import common\n",
        "\n",
        "import base64\n",
        "import imageio\n",
        "import IPython\n",
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import PIL.Image\n",
        "import pyvirtualdisplay\n",
        "import reverb"
      ],
      "metadata": {
        "id": "-uCnL8iXuZyI"
      },
      "execution_count": 29,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "gym.__version__, qt.__version__, pf.__version__"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "69QwFPETuhPv",
        "outputId": "4db23bad-4dea-4a49-91df-51d7f697b1f1"
      },
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "('0.21.0', '0.5.5', '0.9.2')"
            ]
          },
          "metadata": {},
          "execution_count": 30
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def load_data(syms=('SPY', 'QQQ'), start_date=None, end_date=None):\n",
        "    \"\"\"Download daily bars and build the observation and exchange frames.\n",
        "\n",
        "    Args:\n",
        "        syms: iterable of ticker symbols to download; defaults to SPY and QQQ.\n",
        "        start_date: start of the download window; defaults to 2010-01-01.\n",
        "        end_date: end of the download window; defaults to 2020-12-31.\n",
        "\n",
        "    Returns:\n",
        "        A (df_obs, df_exch) tuple. df_obs holds, per symbol, adjusted OHLC prices,\n",
        "        scaled volume and MACD / RSI / Bollinger band / ATR indicators in\n",
        "        '<sym>:<field>' columns. df_exch holds the unadjusted close of each symbol\n",
        "        (one column per symbol) and is used for order matching.\n",
        "    \"\"\"\n",
        "    from datetime import timedelta\n",
        "    import ta\n",
        "\n",
        "    if start_date is None:\n",
        "        start_date = datetime(2010, 1, 1)\n",
        "    if end_date is None:\n",
        "        end_date = datetime(2020, 12, 31)\n",
        "    max_price_scaler = 1                # prices are left unscaled\n",
        "    max_volume_scaler = 1.5e8\n",
        "    obs_frames = []                     # observation, one frame per symbol\n",
        "    exch_closes = []                    # exchange; for order match\n",
        "\n",
        "    for sym in syms:\n",
        "        bars = yf.download(sym, start=start_date, end=end_date)\n",
        "        # stamp every daily bar at 15:59:59\n",
        "        bars.index = pd.to_datetime(bars.index) + timedelta(hours=15, minutes=59, seconds=59)\n",
        "\n",
        "        exch_closes.append(bars['Close'].rename(sym))\n",
        "\n",
        "        # ratio of adjusted to raw close, applied to the other raw columns\n",
        "        # NOTE(review): volume is scaled by the same price adjustment factor - confirm this is intended\n",
        "        adj_factor = bars['Adj Close'] / bars['Close']\n",
        "        # build a fresh frame so the indicator columns below are not written onto a slice\n",
        "        df = pd.DataFrame({\n",
        "            f'{sym}:open': adj_factor * bars['Open'] / max_price_scaler,\n",
        "            f'{sym}:high': adj_factor * bars['High'] / max_price_scaler,\n",
        "            f'{sym}:low': adj_factor * bars['Low'] / max_price_scaler,\n",
        "            f'{sym}:close': bars['Adj Close'] / max_price_scaler,\n",
        "            f'{sym}:volume': adj_factor * bars['Volume'] / max_volume_scaler,\n",
        "        })\n",
        "\n",
        "        macd = ta.trend.MACD(close=df[f'{sym}:close'])\n",
        "        df[f'{sym}:macd'] = macd.macd()\n",
        "        df[f'{sym}:macd_diff'] = macd.macd_diff()\n",
        "        df[f'{sym}:macd_signal'] = macd.macd_signal()\n",
        "\n",
        "        rsi = ta.momentum.RSIIndicator(close=df[f'{sym}:close'])\n",
        "        df[f'{sym}:rsi'] = rsi.rsi()\n",
        "\n",
        "        bb = ta.volatility.BollingerBands(close=df[f'{sym}:close'], window=20, window_dev=2)\n",
        "        df[f'{sym}:bb_bbm'] = bb.bollinger_mavg()\n",
        "        df[f'{sym}:bb_bbh'] = bb.bollinger_hband()\n",
        "        df[f'{sym}:bb_bbl'] = bb.bollinger_lband()\n",
        "\n",
        "        atr = ta.volatility.AverageTrueRange(high=df[f'{sym}:high'], low=df[f'{sym}:low'], close=df[f'{sym}:close'])\n",
        "        df[f'{sym}:atr'] = atr.average_true_range()\n",
        "\n",
        "        obs_frames.append(df)\n",
        "\n",
        "    # concatenate once instead of growing the frames inside the loop\n",
        "    df_obs = pd.concat(obs_frames, axis=1) if obs_frames else pd.DataFrame()\n",
        "    df_exch = pd.concat(exch_closes, axis=1) if exch_closes else pd.DataFrame()\n",
        "\n",
        "    return df_obs, df_exch"
      ],
      "metadata": {
        "id": "ADPZhpzxunLP"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "df_obs, df_exch = load_data()"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "48Yn4frduqNX",
        "outputId": "932508c9-5130-4cbf-c9b3-03a64443b39c"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[*********************100%***********************]  1 of 1 completed\n",
            "[*********************100%***********************]  1 of 1 completed\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Trading Environment"
      ],
      "metadata": {
        "id": "xA5Ozb6Tu3F2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "look_back = 10\n",
        "cash = 100_000.0\n",
        "max_nav_scaler = cash\n",
        "commission = 0.0001\n",
        "\n",
        "\n",
        "def make_portfolio_env(**step_kwargs):\n",
        "    \"\"\"Build a PortfolioEnv on df_obs/df_exch with the shared cash, commission and scaling.\n",
        "\n",
        "    step_kwargs are forwarded to set_steps() on top of the shared\n",
        "    n_lookback=look_back and n_warmup=50 settings.\n",
        "    \"\"\"\n",
        "    env = PortfolioEnv(df_obs, df_exch)\n",
        "    env.set_cash(cash)\n",
        "    env.set_commission(commission)\n",
        "    env.set_steps(n_lookback=look_back, n_warmup=50, **step_kwargs)\n",
        "    env.set_feature_scaling(max_nav_scaler)\n",
        "    return env\n",
        "\n",
        "\n",
        "train_qt_env = make_portfolio_env(n_maxsteps=250)\n",
        "eval_qt_env = make_portfolio_env(n_maxsteps=2000, n_init_step=504)         # index 504 is 2012-01-03"
      ],
      "metadata": {
        "id": "wN8XGfWDu1EO"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Take one step to see how the environment works."
      ],
      "metadata": {
        "id": "4C0SjxpGwrzu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "o1 = eval_qt_env.reset()\n",
        "action = np.array([0.4, 0.4], dtype=np.float64)\n",
        "o2, reward, done, info = eval_qt_env.step(action)"
      ],
      "metadata": {
        "id": "y5MO5Z39vPUu"
      },
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "There are two stocks with 13 features each, plus NAV, for a total of 27 features.\n",
        "\n",
        "The lookback window is 10 trading days, i.e. two weeks."
      ],
      "metadata": {
        "id": "MLpL8n1Sw1Qu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "o1.shape, o2.shape"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ivyPb3HjvjMO",
        "outputId": "8b277d52-fb32-4164-b087-de40a0a4cfe8"
      },
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "((10, 27), (10, 27))"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "idx0 = eval_qt_env._init_step\n",
        "idx1 = idx0+3\n",
        "eval_qt_env._df_obs_scaled[idx0:idx1]         # observation"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 343
        },
        "id": "03DS7pcNwT52",
        "outputId": "5cc577cc-cd87-4980-908e-bfffe5cf2b10"
      },
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-81dee839-195d-48ad-be85-2c4b3e5311c3\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>SPY:open</th>\n",
              "      <th>SPY:high</th>\n",
              "      <th>SPY:low</th>\n",
              "      <th>SPY:close</th>\n",
              "      <th>SPY:volume</th>\n",
              "      <th>SPY:macd</th>\n",
              "      <th>SPY:macd_diff</th>\n",
              "      <th>SPY:macd_signal</th>\n",
              "      <th>SPY:rsi</th>\n",
              "      <th>SPY:bb_bbm</th>\n",
              "      <th>SPY:bb_bbh</th>\n",
              "      <th>SPY:bb_bbl</th>\n",
              "      <th>SPY:atr</th>\n",
              "      <th>QQQ:open</th>\n",
              "      <th>QQQ:high</th>\n",
              "      <th>QQQ:low</th>\n",
              "      <th>QQQ:close</th>\n",
              "      <th>QQQ:volume</th>\n",
              "      <th>QQQ:macd</th>\n",
              "      <th>QQQ:macd_diff</th>\n",
              "      <th>QQQ:macd_signal</th>\n",
              "      <th>QQQ:rsi</th>\n",
              "      <th>QQQ:bb_bbm</th>\n",
              "      <th>QQQ:bb_bbh</th>\n",
              "      <th>QQQ:bb_bbl</th>\n",
              "      <th>QQQ:atr</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>Date</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>2012-01-03 15:59:59</th>\n",
              "      <td>105.490874</td>\n",
              "      <td>106.002808</td>\n",
              "      <td>105.218393</td>\n",
              "      <td>105.276192</td>\n",
              "      <td>1.066237</td>\n",
              "      <td>0.885565</td>\n",
              "      <td>0.284873</td>\n",
              "      <td>0.600691</td>\n",
              "      <td>60.324460</td>\n",
              "      <td>102.609565</td>\n",
              "      <td>105.926763</td>\n",
              "      <td>99.292368</td>\n",
              "      <td>1.674594</td>\n",
              "      <td>51.671302</td>\n",
              "      <td>51.925526</td>\n",
              "      <td>51.526030</td>\n",
              "      <td>51.662224</td>\n",
              "      <td>0.239178</td>\n",
              "      <td>0.060246</td>\n",
              "      <td>0.133801</td>\n",
              "      <td>-0.073555</td>\n",
              "      <td>56.593513</td>\n",
              "      <td>50.688891</td>\n",
              "      <td>52.166454</td>\n",
              "      <td>49.211328</td>\n",
              "      <td>0.876583</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2012-01-04 15:59:59</th>\n",
              "      <td>105.028497</td>\n",
              "      <td>105.532172</td>\n",
              "      <td>104.623908</td>\n",
              "      <td>105.441345</td>\n",
              "      <td>0.700116</td>\n",
              "      <td>0.992830</td>\n",
              "      <td>0.313711</td>\n",
              "      <td>0.679119</td>\n",
              "      <td>60.780561</td>\n",
              "      <td>102.703508</td>\n",
              "      <td>106.223549</td>\n",
              "      <td>99.183468</td>\n",
              "      <td>1.619856</td>\n",
              "      <td>51.580514</td>\n",
              "      <td>51.952772</td>\n",
              "      <td>51.353527</td>\n",
              "      <td>51.880135</td>\n",
              "      <td>0.177978</td>\n",
              "      <td>0.147476</td>\n",
              "      <td>0.176824</td>\n",
              "      <td>-0.029349</td>\n",
              "      <td>57.817202</td>\n",
              "      <td>50.694258</td>\n",
              "      <td>52.188220</td>\n",
              "      <td>49.200296</td>\n",
              "      <td>0.856773</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2012-01-05 15:59:59</th>\n",
              "      <td>104.871587</td>\n",
              "      <td>105.878930</td>\n",
              "      <td>104.392682</td>\n",
              "      <td>105.722046</td>\n",
              "      <td>0.957229</td>\n",
              "      <td>1.087947</td>\n",
              "      <td>0.327062</td>\n",
              "      <td>0.760885</td>\n",
              "      <td>61.588795</td>\n",
              "      <td>102.809845</td>\n",
              "      <td>106.552665</td>\n",
              "      <td>99.067025</td>\n",
              "      <td>1.610313</td>\n",
              "      <td>51.771184</td>\n",
              "      <td>52.352270</td>\n",
              "      <td>51.571434</td>\n",
              "      <td>52.306873</td>\n",
              "      <td>0.249750</td>\n",
              "      <td>0.248179</td>\n",
              "      <td>0.222022</td>\n",
              "      <td>0.026157</td>\n",
              "      <td>60.184417</td>\n",
              "      <td>50.728198</td>\n",
              "      <td>52.332262</td>\n",
              "      <td>49.124135</td>\n",
              "      <td>0.851349</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-81dee839-195d-48ad-be85-2c4b3e5311c3')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-81dee839-195d-48ad-be85-2c4b3e5311c3 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-81dee839-195d-48ad-be85-2c4b3e5311c3');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "                       SPY:open    SPY:high  ...  QQQ:bb_bbl   QQQ:atr\n",
              "Date                                         ...                      \n",
              "2012-01-03 15:59:59  105.490874  106.002808  ...   49.211328  0.876583\n",
              "2012-01-04 15:59:59  105.028497  105.532172  ...   49.200296  0.856773\n",
              "2012-01-05 15:59:59  104.871587  105.878930  ...   49.124135  0.851349\n",
              "\n",
              "[3 rows x 26 columns]"
            ]
          },
          "metadata": {},
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "eval_qt_env._df_exch[idx0:idx1]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 143
        },
        "id": "dR1uB55bwWZu",
        "outputId": "0bb86a88-bd2f-4753-d446-6fe7a43a9319"
      },
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-b376e387-3bd5-4429-8617-6c83228a6616\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>SPY</th>\n",
              "      <th>QQQ</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>2012-01-03 15:59:59</th>\n",
              "      <td>127.500000</td>\n",
              "      <td>56.900002</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2012-01-04 15:59:59</th>\n",
              "      <td>127.699997</td>\n",
              "      <td>57.139999</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2012-01-05 15:59:59</th>\n",
              "      <td>128.039993</td>\n",
              "      <td>57.610001</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b376e387-3bd5-4429-8617-6c83228a6616')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-b376e387-3bd5-4429-8617-6c83228a6616 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-b376e387-3bd5-4429-8617-6c83228a6616');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "                            SPY        QQQ\n",
              "2012-01-03 15:59:59  127.500000  56.900002\n",
              "2012-01-04 15:59:59  127.699997  57.139999\n",
              "2012-01-05 15:59:59  128.039993  57.610001"
            ]
          },
          "metadata": {},
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "At the end of 2012-01-03, if the action is [0.4, 0.4], i.e. 40% in SPY, 40% in QQQ, and the remaining 20% in cash, then we buy 40_000/127.50, rounded down to 313 shares of SPY, and 40_000/56.9, rounded down to 702 shares of QQQ. We pay a commission of (313x127.50+702x56.9)x0.0001=7.985, and the remaining cash is 100_000-313x127.50-702x56.9-7.985=20_140.7, roughly 20% of NAV.\n",
        "\n",
        "Then the market moves to 2012-01-04: SPY goes up to 127.70 and QQQ goes up to 57.14. The positions are now worth 313x127.70 and 702x57.14, respectively, and NAV including cash becomes 313x127.70+702x57.14+20_140.7=100_223.09.\n",
        "\n",
        "The reward is the change in NAV, in this case 223.09.\n",
        "\n",
        "As shown below."
      ],
      "metadata": {
        "id": "6iMXKY9kwcU2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "eval_qt_env._df_positions.iloc[idx0]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UgvsFaTuwZEN",
        "outputId": "23f5b6bb-87b9-4f35-f5a7-a4b48af028f0"
      },
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "SPY          0.0\n",
              "QQQ          0.0\n",
              "Cash    100000.0\n",
              "NAV     100000.0\n",
              "Name: 2012-01-03 15:59:59, dtype: float64"
            ]
          },
          "metadata": {},
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "eval_qt_env._df_positions.iloc[idx0+1]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "y34Oo4lKweRN",
        "outputId": "e1747298-4fe5-4f4a-f732-7b8699fab62f"
      },
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "SPY        313.000000\n",
              "QQQ        702.000000\n",
              "Cash     20140.713799\n",
              "NAV     100223.092415\n",
              "Name: 2012-01-04 15:59:59, dtype: float64"
            ]
          },
          "metadata": {},
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# compare the step reward with the NAV change recorded in the positions table\n",
        "reward, eval_qt_env._df_positions['NAV'].iloc[idx0+1] - eval_qt_env._df_positions['NAV'].iloc[idx0]"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zIYVbnpqwf6m",
        "outputId": "d284015a-ddd1-4c73-9d7f-e403e0b91e2c"
      },
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(223.09241505889892, 223.09241500000644)"
            ]
          },
          "metadata": {},
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Create TF-Agents environment from Gym environment."
      ],
      "metadata": {
        "id": "x7t8ZaBgwoH-"
      }
    },
    {
      "cell_type": "code",
      "source": [
        ""
      ],
      "metadata": {
        "id": "sxL7vtpKwkI2"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}